package cn.wen.rpc.spring;

import cn.wen.rpc.annotation.EnableRpc;
import cn.wen.rpc.annotation.WenReference;
import cn.wen.rpc.annotation.WenService;
import cn.wen.rpc.config.WenRpcConfig;
import cn.wen.rpc.factory.SingletonFactory;
import cn.wen.rpc.netty.client.NettyClient;
import cn.wen.rpc.proxy.WenRpcClientProxy;
import cn.wen.rpc.register.nacos.NacosTemplate;
import cn.wen.rpc.server.WenServiceProvider;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.BeansException;
import org.springframework.beans.factory.config.BeanFactoryPostProcessor;
import org.springframework.beans.factory.config.BeanPostProcessor;
import org.springframework.beans.factory.config.ConfigurableListableBeanFactory;
import org.springframework.beans.factory.support.BeanDefinitionRegistry;
import org.springframework.util.ClassUtils;

import java.lang.reflect.Field;
import java.lang.reflect.Method;

/**
 * 在spring 的bean 初始化 前后进行调用,一般代码都写到 初始化之后
 */
@Slf4j
public class WenRpcSpringBeanPostProcessor implements BeanPostProcessor, BeanFactoryPostProcessor {

    private WenServiceProvider wenServiceProvider;

    private WenRpcConfig wenRpcConfig;
    private NettyClient nettyClient;
    private NacosTemplate nacosTemplate;

    public WenRpcSpringBeanPostProcessor(){
        // 1. 防止线程问题 2. 便于其他类使用
        wenServiceProvider = SingletonFactory.getInstance(WenServiceProvider.class);
        nettyClient = SingletonFactory.getInstance(NettyClient.class);
        nacosTemplate = SingletonFactory.getInstance(NacosTemplate.class);
    }

    @Override
    public Object postProcessBeforeInitialization(Object bean, String beanName) throws BeansException {
        EnableRpc enableRpc = bean.getClass().getAnnotation(EnableRpc.class);
        if (enableRpc != null){
            if (wenRpcConfig == null) {
                log.info("EnableRpc 会先于所有的Bean实例化之前 执行");
                wenRpcConfig = new WenRpcConfig();
                wenRpcConfig.setProviderPort(enableRpc.serverPort());
                wenRpcConfig.setNacosPort(enableRpc.nacosPort());
                wenRpcConfig.setNacosHost(enableRpc.nacosHost());
                wenRpcConfig.setNacosGroup(enableRpc.nacosGroup());
                nettyClient.setWenRpcConfig(wenRpcConfig);
                wenServiceProvider.setWenRpcConfig(wenRpcConfig);
                nacosTemplate.init(wenRpcConfig.getNacosHost(),wenRpcConfig.getNacosPort());

            }
        }
        return bean;
    }

    @Override
    public Object postProcessAfterInitialization(Object bean, String beanName) throws BeansException {
        // 找到WenService注解，以及 WenReference注解
        // bean代表spring的所有能扫描到的bean
        if (bean.getClass().isAnnotationPresent(WenService.class)){
            WenService wenService = bean.getClass().getAnnotation(WenService.class);
            // 加了WenService的bean就被找到了，就把其中的方法 都发布为服务
            wenServiceProvider.publishService(wenService,bean);
        }
        Field[] declaredFields = bean.getClass().getDeclaredFields();
        for (Field declaredField : declaredFields) {
            WenReference wenReference = declaredField.getAnnotation(WenReference.class);
            if (wenReference != null){
                // 找到了加了WenReference的字段，就要生成代理类，当接口方法调用的时候，实际上就是访问的代理类
                // 中的invoke方法
                WenRpcClientProxy wenRpcClientProxy = new WenRpcClientProxy(wenReference,nettyClient);
                Object proxy = wenRpcClientProxy.getProxy(declaredField.getType());
                // 当isAccessible()的结果是false时不允许通过反射访问该字段
                declaredField.setAccessible(true);
                try {
                    declaredField.set(bean, proxy);
                } catch (IllegalAccessException e) {
                    e.printStackTrace();
                }
            }
        }
        return bean;
    }

    @Override
    public void postProcessBeanFactory(ConfigurableListableBeanFactory beanFactory) throws BeansException {
        if (beanFactory instanceof BeanDefinitionRegistry) {
            try {
                // init scanner
                Class<?> scannerClass = ClassUtils.forName ( "org.springframework.context.annotation.ClassPathBeanDefinitionScanner",
                        WenRpcSpringBeanPostProcessor.class.getClassLoader () );
                Object scanner = scannerClass.getConstructor ( new Class<?>[]{BeanDefinitionRegistry.class, boolean.class} )
                        .newInstance ( new Object[]{(BeanDefinitionRegistry) beanFactory, true} );
                // add filter
                Class<?> filterClass = ClassUtils.forName ( "org.springframework.core.type.filter.AnnotationTypeFilter",
                        WenRpcSpringBeanPostProcessor.class.getClassLoader () );
                Object filter = filterClass.getConstructor ( Class.class ).newInstance ( EnableRpc.class );
                Method addIncludeFilter = scannerClass.getMethod ( "addIncludeFilter",
                        ClassUtils.forName ( "org.springframework.core.type.filter.TypeFilter", WenRpcSpringBeanPostProcessor.class.getClassLoader () ) );
                addIncludeFilter.invoke ( scanner, filter );
                // scan packages
                Method scan = scannerClass.getMethod ( "scan", new Class<?>[]{String[].class} );
                scan.invoke ( scanner, new Object[]{"cn.wen.rpc.annotation"} );
            } catch (Throwable e) {
                // spring 2.0
            }
        }
    }
}
